import csv
import time
import sys
import os
from tokenize import String

from util.graphdb_base import GraphDBBase

class DopamineLabeler(GraphDBBase):
    """Add dopamine-study labels to Participant, Gene and Go nodes in the graph.

    Each ``label_*`` method reads one CSV file and runs one Cypher statement
    per relevant data row (participants) or per relevant header column
    (genes, GO terms).  All statements of one call share a single transaction
    that is committed once at the end; if any statement or row fails, the
    error is printed, re-raised, and nothing from that call is committed.
    """

    # Adds the label given in $dopamineLabel to the matching participant.
    # apoc.create.addLabels is used presumably because a plain Cypher SET
    # cannot take the label name from a parameter; it needs the APOC plugin.
    _PARTICIPANT_QUERY = """
        MATCH (p:Participant {participantId: $participantId})
        CALL apoc.create.addLabels( id(p), [ $dopamineLabel ] ) YIELD node
        RETURN true
    """

    _GENE_QUERY = """
        MATCH (g:Gene {ensemblId: $ensemblId})
        SET g :DopamineGene
    """

    _GO_QUERY = """
        MATCH (g:Go {goId: $goId})
        SET g :DopamineGo
    """

    def __init__(self, argv):
        super().__init__(command=__file__, argv=argv)

    def label_dopamine_participants(self, file):
        """Label every participant listed in *file* as DopamineASD or DopamineDD.

        :param file: path to a comma-separated CSV with a header row. Each data
            row must have an ``id`` column (matched against
            ``Participant.participantId``) and an ``ASD`` column. The string
            ``'1'`` selects ``DopamineASD``; any other value selects
            ``DopamineDD``.
        :raises Exception: whatever reading a row or running the query raised
            (e.g. ``KeyError`` for a missing column). The offending row is
            printed first and the transaction is left uncommitted.
        """
        print("labeling dopamine participants")
        # Read-only open; newline='' is what the csv module expects.
        with open(file, newline='') as participants_file:
            reader = csv.DictReader(participants_file, delimiter=",")
            with self._driver.session() as session:
                tx = session.begin_transaction()
                i = 0
                for row in reader:
                    if not row:
                        continue
                    try:
                        label = "DopamineASD" if row['ASD'] == '1' else "DopamineDD"
                        tx.run(self._PARTICIPANT_QUERY,
                               {"participantId": row['id'], "dopamineLabel": label})
                    except Exception as e:
                        print(e, row, reader.line_num)
                        # Abort instead of silently continuing; the open
                        # session discards the uncommitted transaction.
                        raise
                    i += 1

                tx.commit()
                print(i, "lines processed")

    def _label_header_columns(self, file, prefix, query, param_name):
        """Run *query* once for each header column of *file* that starts with *prefix*.

        Only the CSV header row is used. The column name is passed to *query*
        as the Cypher parameter *param_name*. All statements run in one
        transaction, committed at the end.

        :returns: the number of statements run.
        :raises Exception: whatever running a query raised, after printing the
            offending column name.
        """
        with open(file, newline='') as csv_file:
            reader = csv.DictReader(csv_file, delimiter=",")
            # fieldnames is None for an empty file; treat that as "no columns".
            fieldnames = reader.fieldnames or []
            with self._driver.session() as session:
                tx = session.begin_transaction()
                i = 0
                for field in fieldnames:
                    if not (field and field.startswith(prefix)):
                        continue
                    try:
                        tx.run(query, {param_name: field})
                    except Exception as e:
                        print(e, field, reader.line_num)
                        raise
                    i += 1

                tx.commit()
        return i

    def label_dopamine_genes(self, file):
        """Add the DopamineGene label to every gene named in *file*'s header.

        :param file: path to a comma-separated CSV. Header columns starting
            with ``"ENSG"`` are matched against ``Gene.ensemblId``; all other
            columns and all data rows are ignored.
        """
        print("labeling dopamine genes")
        count = self._label_header_columns(file, "ENSG", self._GENE_QUERY, "ensemblId")
        print(count, "lines processed")

    def label_dopamine_go(self, file):
        """Add the DopamineGo label to every GO term named in *file*'s header.

        :param file: path to a comma-separated CSV. Header columns starting
            with ``"GO:"`` are matched against ``Go.goId``; all other columns
            and all data rows are ignored.
        """
        print("labeling dopamine GO")
        count = self._label_header_columns(file, "GO:", self._GO_QUERY, "goId")
        print(count, "lines processed")

if __name__ == '__main__':
    labeling = DopamineLabeler(argv=sys.argv[1:])

    # Relative path: resolved against the current working directory,
    # not against this script's location.
    base_path = "../dataset/dopamine/"

    try:
        if not os.path.isdir(base_path):
            print(base_path, "isn't a directory")
            sys.exit(1)

        genes_path = os.path.join(base_path, "gene_dosage_vectors_dopamine.csv")
        go_path = os.path.join(base_path, "go_vectors_dopamine.csv")

        start = time.time()
        # The gene dosage file is read twice: its data rows drive participant
        # labeling and its header columns drive gene labeling.
        labeling.label_dopamine_participants(genes_path)
        labeling.label_dopamine_genes(genes_path)
        labeling.label_dopamine_go(go_path)
        end = time.time() - start
    finally:
        # Release the database driver even on sys.exit or a labeling failure.
        labeling.close()

    print("Time to complete:", end)
